
import numpy as np
import tqdm
import os


from src.imports.import_profile import import_profile, import_points



from src.clustering.base_distances_bis import (
    norm_hamming,
    geom_hamming,
    jaccard,
    rms_jaccard,
    geom_jaccard,
    rms_hamming
)

# Lookup table from a distance's name to the function implementing it.
# Insertion order is the order in which the methods are evaluated and printed.
base_distances = dict(
    norm_hamming=norm_hamming,
    geom_hamming=geom_hamming,
    rms_hamming=rms_hamming,
    jaccard=jaccard,
    rms_jaccard=rms_jaccard,
    geom_jaccard=geom_jaccard,
)



def triplet_consistency(profile, positions, distance, num_voters):
    """Return the share of triple checks in which the outer pair is not closer.

    Voters ``0 .. num_voters - 1`` are ordered by ``positions[x]``. For every
    triple ``a, b, c`` taken in that order, two checks are made:
    ``distance(profile[a], profile[c])`` is not smaller than
    ``distance(profile[b], profile[c])``, and not smaller than
    ``distance(profile[a], profile[b])``. The result is the number of checks
    that hold divided by the total number of checks.

    Args:
        profile: indexable by voter index; items are passed to ``distance``.
        positions: indexable by voter index; items are used as sort keys.
        distance: callable taking two profile items and returning a value
            that supports ``<``. Assumed to be deterministic, since each
            ordered pair is evaluated only once and reused across triples.
        num_voters: number of voters to consider.

    Returns:
        The fraction of satisfied checks as a float, or NaN when there are
        fewer than three voters (no triple exists).
    """
    if num_voters < 3:
        return float("nan")

    # One stable sort of all voters gives the same (a, b, c) ordering as
    # sorting each index triple on its own: ties fall back to the voter index.
    order = sorted(range(num_voters), key=lambda x: positions[x])

    # pair[p][q] (p < q) is the distance from the p-th to the q-th voter of
    # the ordering, always called with the earlier voter as first argument.
    pair = [[None] * num_voters for _ in range(num_voters)]
    for p in range(num_voters):
        first = profile[order[p]]
        row = pair[p]
        for q in range(p + 1, num_voters):
            row[q] = distance(first, profile[order[q]])

    satisfied = 0
    for lo in range(num_voters):
        row_lo = pair[lo]
        for mid in range(lo + 1, num_voters):
            d_lo_mid = row_lo[mid]
            row_mid = pair[mid]
            for hi in range(mid + 1, num_voters):
                outer = row_lo[hi]
                # "not <" rather than ">=" so incomparable values (NaN)
                # count as satisfied, as a plain if/else on "<" would.
                satisfied += (not (outer < row_mid[hi])) + (not (outer < d_lo_mid))

    # Two checks per triple: 2 * C(n, 3) = n (n - 1) (n - 2) / 3.
    checks = num_voters * (num_voters - 1) * (num_voters - 2) // 3
    return satisfied / checks


def main(num_tests=100, num_voters=100, num_candidates=100,
         pairs=((0.1, 0.1), (0.015, 0.15))):
    """Print, per distance, the mean and std of the triple-consistency rate.

    For every ``(lower_radius, upper_radius)`` in ``pairs`` and every test
    index below ``num_tests``, the sampled profile and voter points are read
    from ``data/sampled/euclidean`` and ``triplet_consistency`` is evaluated
    for each entry of ``base_distances``. After all tests of a pair, one line
    per distance is printed with the mean and standard deviation (rounded to
    three decimals) of the collected rates.

    The candidate point files are not read: the statistic does not use them.
    """
    rules = list(base_distances)

    for lower_radius, upper_radius in pairs:
        rates = {method: [] for method in rules}

        for t in tqdm.tqdm(range(num_tests)):
            file_path = f'data/sampled/euclidean/profile_{num_voters}_{num_candidates}_{t}_{lower_radius}_{upper_radius}'
            file_path = os.path.join(os.getcwd(), file_path)
            P = import_profile(file_path)

            v_path = f'data/sampled/euclidean/v_points_{num_voters}_{num_candidates}_{t}_{lower_radius}_{upper_radius}.csv'
            v_points = import_points(v_path)

            for method in rules:
                rates[method].append(
                    triplet_consistency(P, v_points, base_distances[method], num_voters)
                )

        for method in rules:
            print(f"{method}:  {round(np.mean(rates[method]),3)} {round(np.std(rates[method]),3)}")


if __name__ == "__main__":
    main()
